{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification heads"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Any 🤗 SetFit model consists of two parts: a [SentenceTransformer](https://sbert.net/) embedding body and a classification head. \n",
    "\n",
    "This guide will show you:\n",
    "* The built-in logistic regression classification head\n",
    "* The built-in differentiable classification head\n",
    "* The requirements for a custom classification head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic Regression classification head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When a new SetFit model is initialized, a [scikit-learn logistic regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) head is chosen by default. This has been shown to be highly effective when applied on top of a finetuned sentence transformer body, and it remains the recommended classification head. Initializing a new SetFit model with a Logistic Regression head is simple:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression()"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from setfit import SetFitModel\n",
    "\n",
    "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\")\n",
    "model.model_head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To initialize the Logistic Regression head (or any other head) with additional parameters, use the `head_params` argument on `SetFitModel.from_pretrained()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(max_iter=300, solver='liblinear')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from setfit import SetFitModel\n",
    "\n",
    "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\", head_params={\"solver\": \"liblinear\", \"max_iter\": 300})\n",
    "model.model_head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Differentiable classification head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SetFit also provides [SetFitHead](https://huggingface.co/docs/setfit/main/en/reference/main#setfit.SetFitHead) as a classification head implemented purely in `torch`. It uses a linear layer to map the embeddings to the classes. It can be used by setting the `use_differentiable_head` argument on `SetFitModel.from_pretrained()` to `True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SetFitHead({'in_features': 384, 'out_features': 2, 'temperature': 1.0, 'bias': True, 'device': 'cuda'})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from setfit import SetFitModel\n",
    "\n",
    "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\", use_differentiable_head=True)\n",
    "model.model_head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, this will assume binary classification. To change that, also set the `out_features` via `head_params` to the number of classes that you are using."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SetFitHead({'in_features': 384, 'out_features': 5, 'temperature': 1.0, 'bias': True, 'device': 'cuda'})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from setfit import SetFitModel\n",
    "\n",
    "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\", use_differentiable_head=True, head_params={\"out_features\": 5})\n",
    "model.model_head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip warning={true}>\n",
    "\n",
    "Unlike the default Logistic Regression head, the differentiable classification head only supports integer labels in the following range: `[0, num_classes)`.\n",
    "\n",
    "</Tip>"
   ]
  },
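  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your labels are strings or otherwise fall outside that range, one option is to map them to integer ids before training. The cell below is a minimal sketch in plain Python; the example labels and the `label2id` and `id2label` names are illustrative assumptions rather than part of the SetFit API. The number of distinct labels is the value you would pass as `out_features` via `head_params`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical string labels; replace these with the labels from your own dataset\n",
    "labels = [\"negative\", \"positive\", \"neutral\", \"positive\"]\n",
    "\n",
    "# Assign each distinct label an integer id in [0, num_classes)\n",
    "label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}\n",
    "# Inverse mapping, to turn predicted ids back into the original labels\n",
    "id2label = {idx: label for label, idx in label2id.items()}\n",
    "\n",
    "encoded = [label2id[label] for label in labels]\n",
    "print(encoded)\n",
    "# => [0, 2, 1, 2]\n",
    "print(len(label2id))\n",
    "# => 3"
   ]
  },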
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training with a differentiable classification head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the [SetFitHead](https://huggingface.co/docs/setfit/main/en/reference/main#setfit.SetFitHead) unlocks some new [TrainingArguments](https://huggingface.co/docs/setfit/main/en/reference/trainer#setfit.TrainingArguments) that are not used with a sklearn-based head. Note that training with SetFit consists of two phases behind the scenes: **finetuning embeddings** and **training a classification head**. As a result, some of the training arguments can be tuples, where the two values are used for each of the two phases, respectively. In many of these cases, the second value is only used if the classification head is differentiable. For example:\n",
    "\n",
    "* **batch_size**: (`Union[int, Tuple[int, int]]`, defaults to `(16, 2)`) - The second value in the tuple determines the batch size when training the differentiable SetFitHead.\n",
    "* **num_epochs**: (`Union[int, Tuple[int, int]]`, defaults to `(1, 16)`) - The second value in the tuple determines the number of epochs when training the differentiable SetFitHead. In practice, `num_epochs` is usually larger for training the classification head. There are two reasons for this:\n",
    "\n",
    "    1. This training phase does not train with contrastive pairs, so unlike when finetuning the embedding model, you only get one training sample per labeled training text.\n",
    "    2. This training phase involves training a classifier from scratch, not finetuning an already capable model, so it needs more training steps.\n",
    "* **end_to_end**: (`bool`, defaults to `False`) - If `True`, train the entire model end-to-end during the classifier training phase. Otherwise, freeze the Sentence Transformer body and only train the head.\n",
    "* **body_learning_rate**: (`Union[float, Tuple[float, float]]`, defaults to `(2e-5, 1e-5)`) - The second value in the tuple determines the learning rate of the Sentence Transformer body during the classifier training phase. This is only relevant if `end_to_end` is `True`, as otherwise the Sentence Transformer body is frozen when training the classifier.\n",
    "* **head_learning_rate**: (`float`, defaults to `1e-2`) - This value determines the learning rate of the differentiable head during the classifier training phase. It is only used if the differentiable head is used.\n",
    "* **l2_weight**: (`float`, *optional*) - Optional L2 weight for both the model body and head, passed to the `AdamW` optimizer in the classifier training phase only if a differentiable head is used.\n",
    "\n",
    "For example, a full training script using a differentiable classification head may look something like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Initializing a new SetFit model\n",
    "model = SetFitModel.from_pretrained(\"BAAI/bge-small-en-v1.5\", use_differentiable_head=True, head_params={\"out_features\": 2})\n",
    "\n",
    "# Preparing the dataset\n",
    "dataset = load_dataset(\"SetFit/sst2\")\n",
    "train_dataset = sample_dataset(dataset[\"train\"], label_column=\"label\", num_samples=32)\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "# Preparing the training arguments\n",
    "args = TrainingArguments(\n",
    "    batch_size=(32, 16),\n",
    "    num_epochs=(3, 8),\n",
    "    end_to_end=True,\n",
    "    body_learning_rate=(2e-5, 5e-6),\n",
    "    head_learning_rate=2e-3,\n",
    "    l2_weight=0.01,\n",
    ")\n",
    "\n",
    "# Preparing the trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    ")\n",
    "trainer.train()\n",
    "# ***** Running training *****\n",
    "#   Num examples = 66\n",
    "#   Num epochs = 3\n",
    "#   Total optimization steps = 198\n",
    "#   Total train batch size = 3\n",
    "# {'embedding_loss': 0.2204, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}                                                                                 \n",
    "# {'embedding_loss': 0.0058, 'learning_rate': 1.662921348314607e-05, 'epoch': 0.76}                                                                                  \n",
    "# {'embedding_loss': 0.0026, 'learning_rate': 1.101123595505618e-05, 'epoch': 1.52}                                                                                  \n",
    "# {'embedding_loss': 0.0022, 'learning_rate': 5.393258426966292e-06, 'epoch': 2.27}                                                                                  \n",
    "# {'train_runtime': 36.6756, 'train_samples_per_second': 172.758, 'train_steps_per_second': 5.399, 'epoch': 3.0}                                                     \n",
    "# 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [00:30<00:00,  6.45it/s] \n",
    "# The `max_length` is `None`. Using the maximum acceptable length according to the current model body: 512.\n",
    "# Epoch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.03it/s]\n",
    "\n",
    "# Evaluating\n",
    "metrics = trainer.evaluate(test_dataset)\n",
    "print(metrics)\n",
    "# => {'accuracy': 0.8632619439868204}\n",
    "\n",
    "# Performing inference\n",
    "preds = model.predict([\n",
    "    \"It's a charming and often affecting journey.\",\n",
    "    \"It's slow -- very, very slow.\",\n",
    "    \"A sometimes tedious film.\",\n",
    "])\n",
    "print(preds)\n",
    "# => tensor([1, 0, 0], device='cuda:0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom classification head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alongside the two built-in options, SetFit allows you to specify a custom classification head. There are two forms of supported heads: a custom **differentiable** head or a custom **non-differentiable** head. Both forms must implement a `predict` and a `predict_proba` method; the remaining requirements differ per form and are listed in the sections below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom differentiable head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A custom differentiable head must follow these requirements:\n",
    "\n",
    "* Must subclass `nn.Module`.\n",
    "* A `predict` method: `(self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs]` - This method classifies the embeddings. The output must be integers in the range of `[0, num_classes)`.\n",
    "* A `predict_proba` method: `(self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs, num_classes]` - This method classifies the embeddings into probabilities for each class. For each input, the tensor of size `num_classes` must sum to 1. Applying `torch.argmax(output, dim=-1)` should result in the output for `predict`.\n",
    "* A `get_loss_fn` method: `(self) -> nn.Module` - Returns an initialized loss function, e.g. `torch.nn.CrossEntropyLoss()`.\n",
    "* A `forward` method: `(self, Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]` - Given the output from the Sentence Transformer body, i.e. a dictionary of `'input_ids'`, `'token_type_ids'`, `'attention_mask'`, `'token_embeddings'` and `'sentence_embedding'` keys, return a dictionary with a `'logits'` key and a `torch.Tensor` value with shape `[batch_size, num_classes]`."
   ]
  },
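  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of these requirements, the cell below defines a small two-layer head. `MLPHead` and its `hidden_size` parameter are hypothetical names chosen for illustration, not part of SetFit, and `in_features=384` is assumed to match the embedding size of the body used earlier in this guide. The sketch is only exercised on random tensors here, so treat it as a starting point rather than a verified drop-in head."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "# Hypothetical custom differentiable head: a small two-layer MLP\n",
    "class MLPHead(nn.Module):\n",
    "    def __init__(self, in_features: int = 384, hidden_size: int = 128, num_classes: int = 2) -> None:\n",
    "        super().__init__()\n",
    "        self.layers = nn.Sequential(\n",
    "            nn.Linear(in_features, hidden_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_size, num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        # The body output carries the pooled embedding under 'sentence_embedding'\n",
    "        return {\"logits\": self.layers(features[\"sentence_embedding\"])}\n",
    "\n",
    "    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:\n",
    "        # Softmax over the class dimension, so every row sums to 1\n",
    "        with torch.no_grad():\n",
    "            return torch.softmax(self.layers(embeddings), dim=-1)\n",
    "\n",
    "    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:\n",
    "        # Integer class ids in [0, num_classes)\n",
    "        return torch.argmax(self.predict_proba(embeddings), dim=-1)\n",
    "\n",
    "    def get_loss_fn(self) -> nn.Module:\n",
    "        return nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "# Quick shape and consistency checks on random embeddings\n",
    "head = MLPHead(in_features=384, hidden_size=128, num_classes=2)\n",
    "embeddings = torch.randn(4, 384)\n",
    "probs = head.predict_proba(embeddings)\n",
    "assert torch.allclose(probs.sum(dim=-1), torch.ones(4))\n",
    "assert torch.equal(head.predict(embeddings), torch.argmax(probs, dim=-1))\n",
    "print(head({\"sentence_embedding\": embeddings})[\"logits\"].shape)\n",
    "# => torch.Size([4, 2])"
   ]
  },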
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom non-differentiable head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A custom non-differentiable head must follow these requirements:\n",
    "\n",
    "* A `predict` method: `(self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs]` - This method classifies the embeddings. The output must be integers in the range of `[0, num_classes)`.\n",
    "* A `predict_proba` method: `(self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs, num_classes]` - This method classifies the embeddings into probabilities for each class. For each input, the array of size `num_classes` must sum to 1. Applying `np.argmax(output, axis=-1)` should result in the output for `predict`.\n",
    "* A `fit` method: `(self, np.array with shape [num_inputs, embedding_size], List[Any]) -> None` - This method must take a `numpy` array of embeddings and a list of corresponding labels. The labels need not be integers per se.\n",
    "\n",
    "Many classifiers from sklearn already fit these requirements, such as [`RandomForestClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier), [`MLPClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html#sklearn.neural_network.MLPClassifier), [`KNeighborsClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier), etc.\n",
    "\n",
    "When initializing a SetFit model using your custom (non-)differentiable classification head, it is recommended to use the regular `__init__` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setfit import SetFitModel\n",
    "from sklearn.svm import LinearSVC\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Initializing a new SetFit model\n",
    "model_body = SentenceTransformer(\"BAAI/bge-small-en-v1.5\")\n",
    "# LinearSVC provides `fit` and `predict` but no `predict_proba`,\n",
    "# so only `model.predict` is available with this head\n",
    "model_head = LinearSVC()\n",
    "model = SetFitModel(model_body, model_head)"
   ]
  },
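  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A head does not have to come from sklearn. As a minimal sketch of the non-differentiable requirements, the cell below implements a nearest-centroid classifier with `numpy`. `NearestCentroidHead` is a hypothetical class, not part of SetFit, and it assumes integer labels `0..num_classes-1` so that the predicted indices coincide with the labels. It is only exercised on synthetic embeddings here; following the pattern above, an instance would be passed as the second argument to `SetFitModel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, List\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Hypothetical custom non-differentiable head: nearest class centroid\n",
    "class NearestCentroidHead:\n",
    "    def fit(self, embeddings: np.ndarray, labels: List[Any]) -> None:\n",
    "        labels = np.asarray(labels)\n",
    "        self.classes_ = np.unique(labels)\n",
    "        # One mean embedding per class, in sorted class order\n",
    "        self.centroids_ = np.stack([embeddings[labels == c].mean(axis=0) for c in self.classes_])\n",
    "\n",
    "    def predict_proba(self, embeddings: np.ndarray) -> np.ndarray:\n",
    "        # Softmax over negative squared distances to each centroid\n",
    "        diffs = embeddings[:, None, :] - self.centroids_[None, :, :]\n",
    "        scores = -np.sum(diffs**2, axis=-1)\n",
    "        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability\n",
    "        exp_scores = np.exp(scores)\n",
    "        return exp_scores / exp_scores.sum(axis=-1, keepdims=True)\n",
    "\n",
    "    def predict(self, embeddings: np.ndarray) -> np.ndarray:\n",
    "        # Indices into `classes_`; equal to the labels under the assumption above\n",
    "        return np.argmax(self.predict_proba(embeddings), axis=-1)\n",
    "\n",
    "\n",
    "# Synthetic, well-separated \"embeddings\" for two classes\n",
    "rng = np.random.default_rng(0)\n",
    "embeddings = np.concatenate([rng.normal(-2.0, 0.5, size=(8, 4)), rng.normal(2.0, 0.5, size=(8, 4))])\n",
    "labels = [0] * 8 + [1] * 8\n",
    "\n",
    "centroid_head = NearestCentroidHead()\n",
    "centroid_head.fit(embeddings, labels)\n",
    "assert np.allclose(centroid_head.predict_proba(embeddings).sum(axis=-1), 1.0)\n",
    "print(centroid_head.predict(embeddings))\n",
    "# => [0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]"
   ]
  },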
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training and inference then proceed as usual, e.g.:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from setfit import Trainer, TrainingArguments, sample_dataset\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Preparing the dataset\n",
    "dataset = load_dataset(\"SetFit/sst2\")\n",
    "train_dataset = sample_dataset(dataset[\"train\"], label_column=\"label\", num_samples=32)\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "# Preparing the training arguments\n",
    "args = TrainingArguments(\n",
    "    batch_size=32,\n",
    "    num_epochs=3,\n",
    ")\n",
    "\n",
    "# Preparing the trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    ")\n",
    "trainer.train()\n",
    "\n",
    "# Evaluating\n",
    "metrics = trainer.evaluate(test_dataset)\n",
    "print(metrics)\n",
    "# => {'accuracy': 0.8638110928061504}\n",
    "\n",
    "# Performing inference\n",
    "preds = model.predict([\n",
    "    \"It's a charming and often affecting journey.\",\n",
    "    \"It's slow -- very, very slow.\",\n",
    "    \"A sometimes tedious film.\",\n",
    "])\n",
    "print(preds)\n",
    "# => tensor([1, 0, 0], dtype=torch.int32)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
